package org.hepeng.workx.mybatis.datasource;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.apache.ibatis.session.SqlSession;
import org.hepeng.workx.extension.XLoader;
import org.hepeng.workx.jdbc.DataSourceRoute;
import org.hepeng.workx.mybatis.binding.MapperRegistryGetMapperObjectProcessor;
import org.hepeng.workx.mybatis.executor.SQLExecuteContext;
import org.hepeng.workx.util.proxy.Invocation;
import org.hepeng.workx.util.proxy.Invoker;
import org.hepeng.workx.util.proxy.ProxyFactory;
import org.hepeng.workx.util.proxy.ProxyUtils;
import org.joor.Reflect;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@link MapperRegistryGetMapperObjectProcessor} that wires {@link DataSourceRoute}-based data
 * source routing into MyBatis mappers.
 *
 * <p>When a mapper is handed out, two things happen:
 * <ul>
 *   <li>The {@link DataSourceRoute} annotations declared on the mapper interface are cached.</li>
 *   <li>The mapper's JDK-proxy {@link InvocationHandler} is replaced with a proxy of the same
 *       handler class. That proxy is created through {@link ProxyFactory} and intercepted by an
 *       {@link Invoker}.</li>
 * </ul>
 *
 * <p>Before each intercepted call, the invoker publishes the route cached for the invoked method
 * into {@link SQLExecuteContext}. The route may be {@code null} when the method is not annotated.
 * The context is closed after the call completes.
 *
 * @author he peng
 */
public class DataSourceRouteObjectProcessor implements MapperRegistryGetMapperObjectProcessor {

    /** Routes declared on mapper interface methods, keyed by the annotated {@link Method}. */
    private static final Map<Method , DataSourceRoute> DATASOURCE_ROUTE_MAP = new ConcurrentHashMap<>();

    /** Key under which the resolved route is stored in {@link SQLExecuteContext}. */
    private static final String DATASOURCE_ROUTE_CONTEXT_KEY = "dataSourceRoute";

    /**
     * Position of the {@link Method} argument in
     * {@link InvocationHandler#invoke(Object, Method, Object[])}.
     */
    private static final int INVOKED_METHOD_ARG_INDEX = 1;

    protected ProxyFactory proxyFactory = XLoader.getXLoader(ProxyFactory.class).getX();

    /**
     * No pre-processing is required.
     *
     * @param preProcessObject the object to pre-process
     * @return {@code preProcessObject}, unchanged
     */
    @Override
    public GetMapperPreProcessObject preProcess(GetMapperPreProcessObject preProcessObject) {
        return preProcessObject;
    }

    /**
     * Caches the routes of the mapper interface and swaps the mapper's invocation handler for a
     * routing-aware proxy.
     *
     * <p>The replacement handler is built with constructor signature
     * {@code (SqlSession, Class, Map)}. It receives the original handler's {@code methodCache}
     * field value, read reflectively.
     *
     * @param postProcessObject carries the mapper type, the {@link SqlSession} and the mapper
     *                          instance
     * @return {@code postProcessObject}, whose mapper now delegates to the new handler
     */
    @Override
    public GetMapperPostProcessObject postProcess(GetMapperPostProcessObject postProcessObject) {
        Class<?> type = postProcessObject.getType();
        SqlSession sqlSession = postProcessObject.getSqlSession();
        Object mapper = postProcessObject.getMapper();

        cacheDataSourceRoute(type);

        InvocationHandler mapperProxy = Proxy.getInvocationHandler(mapper);
        // Hand the original handler's method cache to the replacement (presumably so both share
        // already-resolved mapper methods; verify against the handler implementation).
        Object methodCache = Reflect.on(mapperProxy).get("methodCache");

        List<Class<?>> constructorArgTypes = new ArrayList<>(3);
        constructorArgTypes.add(SqlSession.class);
        constructorArgTypes.add(Class.class);
        constructorArgTypes.add(Map.class);

        List<Object> constructorArgs = new ArrayList<>(3);
        constructorArgs.add(sqlSession);
        constructorArgs.add(type);
        constructorArgs.add(methodCache);

        List<Invoker> invokers = new LinkedList<>();
        invokers.add(newDataSourceRouteInvoker());

        Object proxy = proxyFactory.createProxy(mapperProxy.getClass() , null , constructorArgTypes , constructorArgs , invokers , null);
        // Overwrite java.lang.reflect.Proxy's handler field "h" in place, so the existing mapper
        // instance routes every call through the new handler.
        Reflect.on(mapper).set("h" , proxy);
        postProcessObject.setMapper(mapper);
        return postProcessObject;
    }

    /**
     * Creates the interceptor installed on the replacement handler.
     *
     * <p>It is created in a static method so the anonymous invoker does not capture this
     * processor instance; it only needs the static route cache.
     *
     * @return an invoker that publishes the route for the invoked mapper method, then proceeds
     */
    private static Invoker newDataSourceRouteInvoker() {
        return new Invoker() {
            @Override
            public Object invoke(Invocation invocation) throws Throwable {
                try {
                    if (ProxyUtils.isProxy(invocation.getProxy().getClass())) {
                        Method mapperMethod = invokedMapperMethod(invocation.getArgs());
                        // Only publish a route for calls shaped like InvocationHandler#invoke;
                        // other calls on the handler (toString, hashCode, ...) used to crash here.
                        if (mapperMethod != null) {
                            SQLExecuteContext context = SQLExecuteContext.getContext();
                            context.set(DATASOURCE_ROUTE_CONTEXT_KEY, DATASOURCE_ROUTE_MAP.get(mapperMethod));
                        }
                    }
                    return invocation.invoke();
                } finally {
                    // NOTE(review): this closes the context even when no route was published by
                    // this invocation; kept as-is to preserve the existing lifecycle.
                    SQLExecuteContext.close();
                }
            }
        };
    }

    /**
     * Extracts the invoked mapper method from intercepted handler arguments.
     *
     * @param args the intercepted call's arguments; may be {@code null}
     * @return the {@link Method} at {@link #INVOKED_METHOD_ARG_INDEX}, or {@code null} when the
     *         arguments do not have the {@code (Object, Method, Object[])} shape
     */
    private static Method invokedMapperMethod(Object[] args) {
        if (args == null || args.length <= INVOKED_METHOD_ARG_INDEX) {
            return null;
        }
        Object candidate = args[INVOKED_METHOD_ARG_INDEX];
        return candidate instanceof Method ? (Method) candidate : null;
    }

    /**
     * Records the {@link DataSourceRoute} of every method of {@code type} reported by
     * {@link MethodUtils#getMethodsListWithAnnotation(Class, Class)}.
     *
     * @param type the mapper interface
     */
    private void cacheDataSourceRoute(Class<?> type) {
        List<Method> methods = MethodUtils.getMethodsListWithAnnotation(type, DataSourceRoute.class);
        if (CollectionUtils.isNotEmpty(methods)) {
            for (Method method : methods) {
                DATASOURCE_ROUTE_MAP.put(method , method.getAnnotation(DataSourceRoute.class));
            }
        }
    }
}
